[CPU][float8] Add scaled_embedding_bag kernel #2686
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2686
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures: As of commit df09264 with merge base 9056c46.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/quantization/test_quant_api.py (outdated)

```python
    "CPU" not in torch._C._dispatch_dump("torchao::qembeddingbag"),
    reason="cpp kernels not built",
)
def test_embeddingbag_cpu(self):
```
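For orientation, a minimal sketch of the full guard this fragment belongs to; the surrounding test class and `unittest.skipIf` call are reconstructed here, not taken from the diff, and the class name is hypothetical:

```python
import unittest

import torch


class TestQuantAPI(unittest.TestCase):  # hypothetical class name
    @unittest.skipIf(
        # skip when the C++ kernel is not registered for the CPU dispatch key
        "CPU" not in torch._C._dispatch_dump("torchao::qembeddingbag"),
        reason="cpp kernels not built",
    )
    def test_embeddingbag_cpu(self):
        ...
```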
the test should be added here I think: https://github.com/pytorch/ao/blob/main/test/test_ops.py
@pytorchbot label "topic: new feature"
LGTM. Have you run any benchmarks to make sure it's not too slow?
@jerryzh168 Could you help review this PR?
torchao/ops.py (outdated)

```diff
@@ -70,6 +70,9 @@
 lib.define(
     "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
 )
+lib.define(
+    "qembeddingbag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
+)
```
is this the same as https://github.com/pytorch/pytorch/blob/371eacb2ae4ecdabc52ea4634ed21558df2f3bab/aten/src/ATen/native/native_functions.yaml#L2368C1-L2369C1, with the only difference being that qweight is float8?
@jerryzh168 Thanks for reviewing. Yes, I think so, except that the implementation in this PR has limited functionality so far.
This operator is used for inference, so I did not add any parameters related to the gradient, including scale_grad_by_freq, sparse, per_sample_weights, and padding_idx.
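For comparison, the full-precision op carries those extra arguments. The signature of `torch.nn.functional.embedding_bag` is sketched below from memory, so treat the defaults as approximate:

```python
# torch.nn.functional.embedding_bag (approximate signature, for comparison
# with the reduced inference-only schema above)
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                  scale_grad_by_freq=False, mode="mean", sparse=False,
                  per_sample_weights=None, include_last_offset=False,
                  padding_idx=None):
    ...
```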
I think we should add this to PyTorch directly if that's the case. float8 is a native dtype in PyTorch, so it makes the most sense to add the functionality there; we can error out in the op if some arg combination is not supported or invalid for float8.
Intel platforms have fp8 instructions. When we are ready, we hope to update this kernel to use them. As far as I know, the latest GCC is required. Would that be difficult to support in PyTorch?
So, since this PR adds a quantized version of this op, do you think it is better to add it in torchao rather than in torch core? Thanks.
yeah, my question is whether this can be implemented by extending the embedding_bag op in PyTorch and doing the scaling in torchao, or would performance be a concern here?
This is a memory-bound operator, so repeated reads and writes lead to significant performance degradation. For example, if we originally need a single read/write pass (and DLRM hits this op many times), doing the scaling separately requires two passes, cutting performance roughly in half.
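To make the trade-off concrete, here is a minimal sketch of the unfused two-pass alternative being discussed; the broadcasting of `weight_scale` and the exact semantics of `o_scale` are assumptions, not details confirmed in this thread:

```python
import torch
import torch.nn.functional as F

def scaled_embedding_bag_unfused(qweight, indices, offsets, weight_scale, o_scale):
    # pass 1: dequantize the whole fp8 table (a full extra read + write)
    # weight_scale assumed already broadcastable (e.g. a scalar, or per-row
    # scales reshaped to [-1, 1])
    weight = qweight.to(torch.float32) * weight_scale
    # pass 2: the actual bag lookup and sum reduction
    out = F.embedding_bag(indices, weight, offsets, mode="sum",
                          include_last_offset=True)
    return out * o_scale  # output scaling; multiplication is an assumption
```

The fused kernel in this PR avoids the first full pass over the table, which is where the roughly 2x saving described above comes from.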
OK, sounds good. Maybe rename this to _scaled_embedding_bag to follow these ops: https://github.com/pytorch/pytorch/blob/31a41daff49f2cde941d8b9e35cb2eaeeb606c0d/aten/src/ATen/native/native_functions.yaml#L7135. The leading underscore indicates it's a prototype op, since you may want to update the arg list, expand hardware coverage, etc. later.
Done
test/test_ops.py (outdated)

```python
        mode_enum,
        include_last_offset,
    ).to(dtype)
    torch.testing.assert_close(refe_out, test_out, atol=0, rtol=0)
```
is this too strict?
changed to 1e-5
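Presumably the relaxed assertion looks something like the line below; `assert_close` requires `rtol` and `atol` to be set together, so the `rtol` value is an assumption:

```python
torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
```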
@pytorchbot merge
Merge failed. Reason: 1 mandatory check(s) are pending/not yet run.
Dig deeper by viewing the pending checks on hud
we just merge manually with the button in torchao
also, is this op built by default? I think ideally it should be optional so it does not impact the normal build; we have seen errors where kernels from prototype features broke the torchao build
Like the other kernels under cpu/*.cpp, it is not built by default; it is built only with USE_CPU_KERNELS=1.
Introduced recently in #2686
Implemented FP8 QEmbeddingBag on CPU, currently supporting:
- include_last_offset=True
- mode="sum"
Next steps